from Train import train
from utils import get_dataset
from options import args_parser
import matplotlib.pyplot as plt
import numpy as np
def save(a, name):
    """Write ``a`` to disk as a NumPy ``.npy`` file.

    Args:
        a: Anything ``np.array`` accepts (list, nested list, array, ...);
            it is converted to an ndarray before saving.
        name: Output path *without* extension; ``'.npy'`` is always
            appended, so ``save(x, 'acc')`` writes ``acc.npy``.
    """
    as_array = np.array(a)
    # Build the filename explicitly; np.save would not add a second suffix.
    np.save(f'{name}.npy', as_array)
if __name__ == '__main__':

    # Build the option object from options.args_parser.
    opts = args_parser()

    # Load the dataset. Signature: get_dataset(args, iid, dataset).
    # Here iid=1 and dataset='minst'. Keep that literal spelling as-is: it is
    # presumably what utils.get_dataset matches for MNIST (confirm before
    # "correcting" it). Other combinations previously tried here:
    # (iid=0, 'minst'), (iid=1, 'cifar'), (iid=0, 'cifar').
    train_set, test_set, user_groups = get_dataset(opts, 1, 'minst')

    # Train; the returned test accuracy and loss are only held in memory here.
    global_test_acc, train_loss = train(train_set, test_set, user_groups)

